#***************************************************************************************/
#
#    Based on PointRCNN Library (MIT license):
#    https://github.com/sshaoshuai/PointRCNN
#
#    Copyright (c) 2019 Shaoshuai Shi

#    Permission is hereby granted, free of charge, to any person obtaining a copy
#    of this software and associated documentation files (the "Software"), to deal
#    in the Software without restriction, including without limitation the rights
#    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#    copies of the Software, and to permit persons to whom the Software is
#    furnished to do so, subject to the following conditions:

#    The above copyright notice and this permission notice shall be included in all
#    copies or substantial portions of the Software.

#    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#    SOFTWARE.
#
#***************************************************************************************/

from collections.abc import Iterable
from torch import nn

# BatchNorm layer classes. `split_bn_bias` moves direct children of these
# types into their own parameter groups so that `OptimWrapper` can control
# their weight decay separately (`bn_wd`).
bn_types = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
)


def split_bn_bias(layer_groups):
    """Divide every layer group into a non-batchnorm and a batchnorm part.

    The direct children of each module in `layer_groups` are partitioned by
    whether they are instances of `bn_types`. For each input group two
    ``nn.Sequential`` containers are produced, non-batchnorm first, so the
    returned list is twice as long as `layer_groups`. Child order within each
    part is preserved and the child modules themselves are shared, not copied.
    """
    groups_out = []
    for group in layer_groups:
        children = list(group.children())
        bn_flags = [isinstance(child, bn_types) for child in children]
        regular = [child for child, is_bn in zip(children, bn_flags) if not is_bn]
        norms = [child for child, is_bn in zip(children, bn_flags) if is_bn]
        groups_out.extend((nn.Sequential(*regular), nn.Sequential(*norms)))
    return groups_out


def listify(p=None, q=None):
    """Make `p` listy and the same length as `q`.

    Args:
        p: ``None`` (becomes an empty list), a single value, or an iterable of
            values. Strings and non-iterables are treated as one value; so are
            iterable-typed objects that cannot actually be iterated (e.g. 0-d
            numpy arrays / torch tensors).
        q: Target length. An ``int`` is used directly, ``None`` keeps the
            length of `p`, and any other object contributes its ``len()``.

    Returns:
        A new ``list``. A single-element `p` is repeated to the target length.

    Raises:
        AssertionError: if `p` does not have exactly one element and its
            length differs from the target length.
    """
    if p is None:
        p = []
    elif isinstance(p, str) or not isinstance(p, Iterable):
        p = [p]
    else:
        try:
            # Materialise first: generators have no len(), and "repeating" an
            # ndarray/tensor with `*` would multiply its values instead.
            p = list(p)
        except TypeError:
            # 0-d arrays/tensors claim to be Iterable but refuse iteration;
            # treat them as a single value.
            p = [p]
    n = q if isinstance(q, int) else len(p) if q is None else len(q)
    if len(p) == 1:
        p = p * n
    assert len(p) == n, f'List len mismatch ({len(p)} vs {n})'
    return p


def trainable_params(m: nn.Module):
    """Return list of trainable params in `m`.

    Collects every parameter yielded by ``m.parameters()`` whose
    ``requires_grad`` flag is set. A real ``list`` is returned (rather than a
    lazy, single-use ``filter`` object) so the result can be measured with
    ``len()``, inspected, and iterated more than once.
    """
    return [p for p in m.parameters() if p.requires_grad]


def is_tuple(x) -> bool:
    """Report whether `x` is a tuple instance (tuple subclasses included)."""
    return isinstance(x, (tuple,))


# Adapted from fastai's `OptimWrapper`.
class OptimWrapper():
    """Basic wrapper around `opt` to simplify hyper-parameters changes.

    The wrapped optimizer's ``param_groups`` are handled in pairs. For each
    layer group, the even-indexed group holds its non-batchnorm parameters
    and the following odd-indexed group holds its batchnorm parameters. This
    is the layout `create` builds via `split_bn_bias`. Hyper-parameters are
    kept as one value per layer group and written to both groups of a pair.
    The exception is weight decay: the wrapper applies it to the batchnorm
    groups only when `bn_wd` is set. Attributes not defined on the wrapper
    are looked up on `opt` (missing ones resolve to ``None``).

    Args:
        opt: The optimizer instance to wrap.
        wd: Weight decay, either a scalar or one value per layer group.
        true_wd: If True, apply decoupled weight decay
            (``p *= 1 - wd * lr``) in `step` instead of relying on the
            optimizer's own ``weight_decay``.
        bn_wd: If True, batchnorm parameter groups are decayed too.
    """

    def __init__(self, opt, wd, true_wd: bool = False, bn_wd: bool = True):
        self.opt, self.true_wd, self.bn_wd = opt, true_wd, bn_wd
        # Hyper-parameter names stored on each param group (everything but
        # the parameter list itself).
        self.opt_keys = list(self.opt.param_groups[0].keys())
        self.opt_keys.remove('params')
        self.read_defaults()
        self.wd = wd

    @classmethod
    def create(cls, opt_func, lr, layer_groups, **kwargs):
        """Create an `optim.Optimizer` from `opt_func` with `lr`.

        Each module in `layer_groups` is split into a non-batchnorm and a
        batchnorm parameter group (see `split_bn_bias`). Only parameters with
        ``requires_grad`` set are registered. The groups are created with
        ``lr=0``, and the real learning rate(s) are then set through the
        wrapper.

        Args:
            opt_func: Callable building an optimizer from a list of
                param-group dicts (e.g. an optimizer class).
            lr: Learning rate, a scalar or one value per layer group.
            layer_groups: Modules whose parameters should be optimized.
            **kwargs: Forwarded to the wrapper constructor (``wd``,
                ``true_wd``, ``bn_wd``).

        Returns:
            The new `OptimWrapper`.
        """
        split_groups = split_bn_bias(layer_groups)
        opt = opt_func([{
            'params': trainable_params(group),
            'lr': 0
        } for group in split_groups])
        opt = cls(opt, **kwargs)
        opt.lr, opt.opt_func = listify(lr, layer_groups), opt_func
        return opt

    def new(self, layer_groups):
        """Create a new `OptimWrapper` from `self` with another `layer_groups`
        but the same hyper-parameters.

        Only the last (scalar) values of ``lr`` and ``wd`` are carried over.
        """
        # Read the instance dict directly. `__getattr__` forwards misses to
        # the inner optimizer and returns None, so a plain
        # ``getattr(self, 'opt_func', default)`` would never fall back to the
        # default for wrappers that were not built through `create`.
        opt_func = self.__dict__.get('opt_func', self.opt.__class__)
        return self.create(opt_func,
                           self.lr,
                           layer_groups,
                           wd=self.wd,
                           true_wd=self.true_wd,
                           bn_wd=self.bn_wd)

    def __repr__(self) -> str:
        return f'OptimWrapper over {repr(self.opt)}.\nTrue weight decay: {self.true_wd}'

    # Pytorch optimizer methods
    def step(self, closure=None):
        """Set weight decay and step optimizer.

        With `true_wd`, each trainable parameter is scaled by ``1 - wd * lr``
        before stepping. Batchnorm groups are included only when `bn_wd` is
        set. The optimizer's own ``weight_decay`` is then zeroed so decay is
        not applied twice.

        Args:
            closure: Optional closure that re-evaluates the model and returns
                the loss. It is forwarded to the wrapped optimizer's ``step``.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        # weight decay outside of optimizer step (AdamW)
        if self.true_wd:
            for lr, wd, pg1, pg2 in zip(self._lr, self._wd,
                                        self.opt.param_groups[::2],
                                        self.opt.param_groups[1::2]):
                decayed_groups = [pg1, pg2] if self.bn_wd else [pg1]
                for pg in decayed_groups:
                    for p in pg['params']:
                        # When some parameters are fixed:  Shaoshuai Shi
                        if not p.requires_grad:
                            continue
                        p.data.mul_(1 - wd * lr)
            self.set_val('weight_decay', listify(0, self._wd))
        if closure is None:
            return self.opt.step()
        return self.opt.step(closure)

    def zero_grad(self) -> None:
        """Clear optimizer gradients."""
        self.opt.zero_grad()

    # Passthrough to the inner opt.
    def __getattr__(self, k: str):
        """Forward attribute lookups that fail on the wrapper to `opt`.

        Missing non-dunder attributes resolve to ``None``. Lookups of ``opt``
        itself and of dunder names raise ``AttributeError`` instead. This
        matters when `copy`/`pickle` probe an instance whose ``__dict__`` is
        still empty. Without the guard, ``self.opt`` would re-enter this
        method forever, and protocol hooks such as ``__getstate__`` would be
        taken from the inner optimizer.
        """
        if k == 'opt' or (k.startswith('__') and k.endswith('__')):
            raise AttributeError(k)
        return getattr(self.opt, k, None)

    def clear(self):
        """Reset the state of the inner optimizer."""
        sd = self.opt.state_dict()
        sd['state'] = {}
        self.opt.load_state_dict(sd)

    # Hyperparameters as properties
    @property
    def lr(self) -> float:
        """Learning rate of the last layer group."""
        return self._lr[-1]

    @lr.setter
    def lr(self, val: float) -> None:
        """Set the learning rate (scalar or one value per layer group)."""
        self._lr = self.set_val('lr', listify(val, self._lr))

    @property
    def mom(self) -> float:
        """Momentum of the last layer group."""
        return self._mom[-1]

    @mom.setter
    def mom(self, val: float) -> None:
        """Set momentum: ``momentum`` or, failing that, the first of ``betas``."""
        if 'momentum' in self.opt_keys:
            self.set_val('momentum', listify(val, self._mom))
        elif 'betas' in self.opt_keys:
            self.set_val('betas', (listify(val, self._mom), self._beta))
        self._mom = listify(val, self._mom)

    @property
    def beta(self) -> float:
        """Beta (second of ``betas``, or ``alpha``) of the last layer group."""
        return None if self._beta is None else self._beta[-1]

    @beta.setter
    def beta(self, val: float) -> None:
        """Set beta, or alpha as makes sense, for given optimizer."""
        if val is None:
            return
        if 'betas' in self.opt_keys:
            self.set_val('betas', (self._mom, listify(val, self._beta)))
        elif 'alpha' in self.opt_keys:
            self.set_val('alpha', listify(val, self._beta))
        self._beta = listify(val, self._beta)

    @property
    def wd(self) -> float:
        """Weight decay of the last layer group."""
        return self._wd[-1]

    @wd.setter
    def wd(self, val: float) -> None:
        """Set weight decay.

        With `true_wd` the value is only recorded here and applied in `step`.
        Otherwise it is written to the optimizer's ``weight_decay``, and to the
        batchnorm groups only when `bn_wd` is set.
        """
        if not self.true_wd:
            self.set_val('weight_decay',
                         listify(val, self._wd),
                         bn_groups=self.bn_wd)
        self._wd = listify(val, self._wd)

    # Helper functions
    def read_defaults(self) -> None:
        """Read the values inside the optimizer for the hyper-parameters."""
        self._beta = None
        if 'lr' in self.opt_keys:
            self._lr = self.read_val('lr')
        if 'momentum' in self.opt_keys:
            self._mom = self.read_val('momentum')
        if 'alpha' in self.opt_keys:
            self._beta = self.read_val('alpha')
        if 'betas' in self.opt_keys:
            self._mom, self._beta = self.read_val('betas')
        if 'weight_decay' in self.opt_keys:
            self._wd = self.read_val('weight_decay')

    def set_val(self, key: str, val, bn_groups: bool = True):
        """Set `val` inside the optimizer dictionary at `key`.

        Args:
            key: Param-group key to write (e.g. ``'lr'``).
            val: One value per layer group, or a 2-tuple of such lists. A
                tuple is zipped into per-group pairs (the ``betas`` layout).
            bn_groups: Also write to the odd-indexed (batchnorm) groups.

        Returns:
            The per-layer-group values that were written.
        """
        if is_tuple(val):
            val = list(zip(*val))
        for v, pg1, pg2 in zip(val, self.opt.param_groups[::2],
                               self.opt.param_groups[1::2]):
            pg1[key] = v
            if bn_groups:
                pg2[key] = v
        return val

    def read_val(self, key: str):
        """Read a hyperparameter `key` in the optimizer dictionary.

        Values are taken from the even-indexed (non-batchnorm) groups, one per
        layer group. Tuple-valued entries (e.g. ``betas``) are returned as a
        pair of lists.
        """
        val = [pg[key] for pg in self.opt.param_groups[::2]]
        if is_tuple(val[0]):
            val = [o[0] for o in val], [o[1] for o in val]
        return val
